{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import get_attention_mask, Attention, Transformer, PositionEmbedding, GPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch_size, seq_len, context_size\n",
    "B, L, C = [4, 10, 6]\n",
    "vocab_size = 123\n",
    "max_position = 40\n",
    "attention_head = 3\n",
    "attention_dropout = 0.1\n",
    "residual_dropout = 0.1\n",
    "layer_size = 4\n",
    "embedding_dropout = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For size 3 the expected mask is lower-triangular: ones on and below the\n",
    "# diagonal, zeros above it.\n",
    "expected_mask = [\n",
    "    [1.0, 0.0, 0.0],\n",
    "    [1.0, 1.0, 0.0],\n",
    "    [1.0, 1.0, 1.0],\n",
    "]\n",
    "actual_mask = get_attention_mask(3).numpy().tolist()\n",
    "assert actual_mask == expected_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first element of the Attention output must have the same shape as the\n",
    "# (B, L, C) input.\n",
    "attention_layer = Attention(\n",
    "    embedding_size=C,\n",
    "    num_attention_heads=attention_head,\n",
    "    attention_dropout=attention_dropout,\n",
    "    residual_dropout=residual_dropout\n",
    ")\n",
    "attention_result = attention_layer(tf.random.uniform((B, L, C)))\n",
    "assert attention_result[0].shape == [B, L, C]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first element of the Transformer output must have the same shape as the\n",
    "# (B, L, C) input.\n",
    "transformer_block = Transformer(\n",
    "    C,\n",
    "    num_attention_heads=attention_head,\n",
    "    attention_dropout=attention_dropout,\n",
    "    residual_dropout=residual_dropout\n",
    ")\n",
    "transformer_result = transformer_block(tf.random.uniform((B, L, C)))\n",
    "assert transformer_result[0].shape == [B, L, C]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For a (B, L, C) input the position embedding is expected to come back with a\n",
    "# leading dimension of 1 instead of B.\n",
    "position_layer = PositionEmbedding(max_position, C)\n",
    "position_result = position_layer(tf.random.uniform((B, L, C)))\n",
    "assert position_result.shape == [1, L, C]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feeding (B, L) integer token ids through GPT must produce one vector of\n",
    "# size vocab_size per position.\n",
    "gpt_model = GPT(\n",
    "    vocab_size=vocab_size,\n",
    "    layer_size=layer_size,\n",
    "    block_size=max_position,\n",
    "    embedding_dropout=embedding_dropout,\n",
    "    embedding_size=C,\n",
    "    num_attention_heads=attention_head,\n",
    "    attention_dropout=attention_dropout,\n",
    "    residual_dropout=residual_dropout\n",
    ")\n",
    "token_ids = tf.random.uniform((B, L), dtype=tf.int64, maxval=10)\n",
    "gpt_result = gpt_model(token_ids)\n",
    "assert gpt_result.shape == [B, L, vocab_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Incremental check: run the first L - 1 positions, then feed only the last\n",
    "# position together with the state returned for the prefix. The result must\n",
    "# match the last position of a single pass over the whole sequence.\n",
    "transformer = Transformer(C, attention_head, 0.1, 0.1)\n",
    "inputs = tf.random.uniform((B, L, C))\n",
    "\n",
    "prefix_out, prefix_state = transformer(inputs[:, :-1, :])\n",
    "step_out, step_state = transformer(inputs[:, -1:, :], prefix_state)\n",
    "full_out, full_state = transformer(inputs[:, :, :])\n",
    "\n",
    "assert prefix_out.shape == (B, L - 1, C)\n",
    "assert step_out.shape == (B, 1, C)\n",
    "assert full_out.shape == (B, L, C)\n",
    "\n",
    "# Compare the first batch element only, as a summed squared error.\n",
    "last_step_error = tf.reduce_sum(tf.pow(step_out[0][0] - full_out[0][-1], 2))\n",
    "assert last_step_error.numpy() < 1e-8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end check on the full GPT model. This cell used to build `gpt` and\n",
    "# then re-run the single-block `transformer` checks from the previous cell, so\n",
    "# the model itself was never called. Now the logits computed for a prefix must\n",
    "# match the logits for the same positions when the whole sequence is given.\n",
    "gpt = GPT(\n",
    "    vocab_size=vocab_size,\n",
    "    layer_size=layer_size,\n",
    "    block_size=max_position,\n",
    "    embedding_dropout=embedding_dropout,\n",
    "    embedding_size=C,\n",
    "    num_attention_heads=attention_head,\n",
    "    attention_dropout=attention_dropout,\n",
    "    residual_dropout=residual_dropout\n",
    ")\n",
    "tokens = tf.random.uniform((B, L), dtype=tf.int64, maxval=vocab_size)\n",
    "\n",
    "prefix_logits = gpt(tokens[:, :-1])\n",
    "full_logits = gpt(tokens)\n",
    "assert prefix_logits.shape == [B, L - 1, vocab_size]\n",
    "assert full_logits.shape == [B, L, vocab_size]\n",
    "\n",
    "# The tolerance leaves room for float32 rounding differences between the calls.\n",
    "max_abs_diff = tf.reduce_max(tf.abs(prefix_logits - full_logits[:, :-1]))\n",
    "assert max_abs_diff.numpy() < 1e-4\n",
    "\n",
    "# NOTE(review): cached (incremental) decoding through GPT is not exercised\n",
    "# here; add it once GPT's cache interface is confirmed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
